package cc.unitmesh.cf

import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import kotlin.math.abs
import kotlin.test.Ignore

/**
 * Tests for [STSemantic]: tokenizer round-tripping and sentence embeddings.
 *
 * Both tests are marked `@Ignore`. NOTE(review): presumably this is because [STSemantic.create]
 * needs model/tokenizer assets that are not available in every environment. TODO confirm.
 */
class STSemanticTest {
    /**
     * Encodes "blog" and checks the token ids and the attention mask.
     * Then decodes the ids back to text, including the special tokens.
     */
    @Test
    @Ignore
    fun test_for_encode_decode() {
        val semantic = STSemantic.create()
        val tokenizer = semantic.getTokenizer()
        val embedding = tokenizer.encode("blog")

        // 101 / 9927 / 102 decode to "[CLS] blog [SEP]" (asserted below).
        embedding.ids shouldBe listOf(101L, 9927L, 102L)

        embedding.attentionMask shouldBe listOf(1L, 1L, 1L)

        val text = tokenizer.decode(embedding.ids)
        text shouldBe "[CLS] blog [SEP]"
    }

    /**
     * Embeds two inputs and compares each vector, component by component, against a stored
     * reference vector. The reference vectors are named as Rust-produced outputs.
     */
    @Test
    @Ignore
    fun embed_demo() {
        val semantic = STSemantic.create()
        val embedding = semantic.embed("blog").map {
            it.toFloat()
        }

        val rustVersionOutput = listOf(
            -0.110279165,
            -0.20323151,
            -0.19402212,
            0.5687035,
            0.23401435,
            -0.3581861,
            0.5271091,
            0.15083969,
            0.34416556,
            0.56494397,
            0.22725289,
            0.113635994,
            0.13297962,
            0.16670649,
            0.18186933,
            -0.06564712,
            -0.15182413,
            0.355111,
            -0.7163484,
            -0.04913767,
            -0.663516,
            0.22934864,
            -0.004983654,
            0.48360667,
            -0.05824821,
            0.22712882,
            -0.5377796,
            -0.12723802,
            -0.13869843,
            -0.65868527,
            -0.09670005,
            -0.06010717,
            -0.34063816,
            -0.028756538,
            0.20640771,
            -0.02175857,
            0.060673315,
            -0.16044103,
            0.08249193,
            -0.2743876,
            0.013140768,
            -0.57382923,
            0.09609735,
            -0.011499226,
            -0.008267462,
            -0.2613307,
            0.18937238,
            -0.12832402,
            -0.17780983,
            0.23589806,
            -0.30812678,
            -0.19349045,
            -0.1199696,
            -0.66233987,
            -0.03363161,
            -0.09512878,
            -0.42250982,
            0.4612564,
            -0.03416545,
            -0.59481907,
            0.7756756,
            -0.17171194,
            -0.74515057,
            0.48214105,
            0.109512895,
            -0.060159165,
            0.03312962,
            0.8364305,
            0.5005641,
            -0.383744,
            -0.24040401,
            0.46697685,
            -0.16124786,
            0.47246695,
            0.31220987,
            -0.07845653,
            0.48405328,
            0.0978297,
            0.39955565,
            -0.38739052,
            0.20320205,
            0.37052405,
            -0.3384258,
            0.40567085,
            0.13302103,
            -0.5924873,
            0.6553878,
            0.2492858,
            -0.2113901,
            0.039609537,
            0.09936675,
            0.011963151,
            0.48122612,
            -0.087005295,
            -0.855067,
            0.05195448,
            0.21967377,
            -0.26883313,
            -0.28816518,
            1.5938386,
            0.016555792,
            0.4605228,
            0.3085129,
            0.2848029,
            0.21296711,
            0.09154507,
            -0.38858747,
            0.8437827,
            0.16395916,
            0.1798058,
            -0.14895085,
            0.40634096,
            -0.37596753,
            -0.30977765,
            0.53710127,
            0.04905224,
            0.59237546,
            -0.08046392,
            0.28081483,
            0.1994078,
            -0.21207093,
            0.15927452,
            -0.2675248,
            -0.28623047,
            -0.18719508,
            -0.271465,
            0.08163627,
            -1.7481325E-32,
            0.7117869,
            0.046314757,
            0.08810749,
            0.15848,
            0.5094527,
            0.5384492,
            -0.37314346,
            -0.25140703,
            -0.46735236,
            0.23707624,
            -0.03525045,
            0.04242001,
            -0.014994323,
            0.32990482,
            0.6873948,
            0.1020777,
            -0.22182953,
            0.50239617,
            0.41482937,
            -0.4894677,
            -0.0010047754,
            0.11522704,
            0.4444447,
            0.8956523,
            0.112065874,
            -0.49681196,
            0.27852637,
            -0.4289411,
            -0.35519376,
            0.16644251,
            0.005235652,
            -0.14812987,
            -0.3921808,
            -0.2539396,
            -0.058002263,
            -0.5353127,
            0.44027257,
            -0.5862116,
            0.48723444,
            0.041912634,
            -0.39615956,
            -0.031653907,
            -0.22384675,
            -0.1660247,
            0.47698462,
            0.4447286,
            0.08870772,
            0.024380995,
            -0.16440527,
            0.040698666,
            0.25429553,
            -0.4433404,
            0.12491482,
            -0.024511186,
            -0.1758517,
            0.005671402,
            0.1728404,
            -0.55925035,
            0.7056008,
            0.05783311,
            0.507005,
            0.18514593,
            -0.2557012,
            -0.49320444,
            -0.21200551,
            0.118768364,
            0.1765864,
            -0.051282704,
            0.54867935,
            -0.04089139,
            -0.32395157,
            -0.2656132,
            0.10571144,
            -0.10910803,
            -0.28008866,
            0.10082736,
            -0.46323022,
            -0.41716757,
            -0.8656499,
            0.30254444,
            0.14473222,
            -0.31123692,
            -0.4305289,
            0.29052138,
            0.05334143,
            -0.48228803,
            -0.13061367,
            -0.27647334,
            0.14793135,
            0.18344998,
            -0.24634238,
            0.38107178,
            0.64464915,
            -0.23734172,
            -0.048561174,
            2.568429E-32,
            -0.37187788,
            -0.2801578,
            -0.4365394,
            0.6523815,
            -0.07841512,
            0.057543516,
            0.08467019,
            -0.006827136,
            -0.04405451,
            0.5664455,
            -0.26514426,
            -0.47978082,
            0.1894442,
            0.579729,
            -0.0056093237,
            0.3913045,
            0.114259414,
            -0.6279835,
            -0.6148898,
            0.5018132,
            -0.36704504,
            -0.2511385,
            -0.787425,
            0.1256488,
            0.46383405,
            0.22579008,
            0.34901974,
            0.37562713,
            -0.3048062,
            -0.43956247,
            0.61040634,
            -0.23435688,
            -0.3910474,
            -0.03822125,
            0.0010235012,
            0.7031872,
            0.07300221,
            0.1711,
            -0.03805649,
            -0.20503116,
            0.31147385,
            -0.4113067,
            0.2830288,
            0.81740123,
            -0.4922081,
            0.047529582,
            -0.6890304,
            0.32192668,
            0.036708724,
            0.20717506,
            -0.60736126,
            -0.40349212,
            0.35182294,
            -0.31327686,
            0.027656319,
            -0.09882919,
            -0.46388862,
            -0.1835417,
            -0.37170517,
            0.3430324,
            0.087905385,
            0.39683786,
            -0.24925108,
            0.5404869,
            -0.20595808,
            -0.6918924,
            0.12369304,
            0.136187,
            -0.43145856,
            0.17041536,
            0.93080807,
            0.34820393,
            0.28278932,
            -0.048433464,
            -0.0042013726,
            0.36461926,
            -0.23178981,
            0.52015966,
            -0.23646581,
            -0.2645026,
            -0.07448574,
            -0.19077025,
            0.25949025,
            0.05321908,
            -0.1630237,
            -0.40390202,
            0.27623382,
            0.23623504,
            -0.07707992,
            -0.44474483,
            -0.08484307,
            -0.30488634,
            -0.31630877,
            0.5552085,
            0.26741418,
            -8.796744E-8,
            -0.42388526,
            -0.7835722,
            0.31339005,
            0.18535687,
            0.4912046,
            0.3799179,
            0.46367088,
            0.22452867,
            0.06597305,
            0.6768408,
            -0.0422195,
            -0.31624988,
            0.15983607,
            0.5498157,
            0.10440541,
            -0.68329877,
            0.20370972,
            -0.7067054,
            -0.083996125,
            -0.1465518,
            -0.04802907,
            -0.082606874,
            -0.091461994,
            -0.1620254,
            0.11281832,
            -0.39109156,
            0.052334744,
            0.45771384,
            -0.14178409,
            0.12591025,
            -0.07683225,
            0.20745273,
            -0.10680809,
            -0.07593716,
            0.37663043,
            -0.2506283,
            -0.0801289,
            -0.34662268,
            -0.28494385,
            -0.7509082,
            0.07219892,
            0.04742825,
            0.97876126,
            -0.35832188,
            -0.38610053,
            -0.1257224,
            0.05736194,
            -0.42519012,
            -0.014331664,
            -0.2473176,
            0.043123465,
            0.022188524,
            0.9511275,
            0.6504857,
            0.07993839,
            -0.033392675,
            0.017061293,
            0.17677112,
            -0.18121076,
            0.4190925,
            1.0514423,
            0.1513483,
            0.009485464,
            0.099742986
        )
            .map { it.toFloat() }
        embedding.size shouldBe EMBEDDING_DIMENSION
        assertEmbeddingCloseTo(embedding, rustVersionOutput)

        val embedding2 = semantic.embed("what a wonderful day")
            .map {
                it.toFloat()
            }
        val rustOutput2 = listOf(
            -0.23757131,
            0.42181793,
            0.22047941,
            -0.18726386,
            -0.24583507,
            -0.25108197,
            0.6374419,
            0.085151315,
            0.1442243,
            -0.070295855,
            0.008653805,
            0.43911538,
            0.055730313,
            0.10676991,
            0.26014924,
            0.2896713,
            -0.2596091,
            -0.081983335,
            -0.763128,
            -0.042965587,
            0.026263474,
            0.10714874,
            -0.28731206,
            0.41199112,
            -0.69807863,
            0.091905,
            0.15566444,
            0.12686445,
            0.21448667,
            -0.24315594,
            -0.51446074,
            0.16428299,
            -0.000118116535,
            0.20191996,
            -0.101417355,
            -0.16508886,
            0.37552822,
            -0.46906987,
            0.23551714,
            -0.068285175,
            0.022380188,
            -0.28443417,
            0.25796208,
            0.07599489,
            -0.0076752505,
            -0.13593675,
            0.28115118,
            -0.14344643,
            0.34103608,
            -0.33395728,
            -0.2579864,
            0.016159425,
            -0.19800232,
            -0.13548213,
            0.14577337,
            0.7493315,
            0.21400599,
            -0.39856425,
            0.49935892,
            -0.080888204,
            -0.15620655,
            0.45772076,
            0.6132026,
            0.3952482,
            0.21080546,
            -0.22833854,
            -0.37721452,
            0.36367285,
            -0.35356832,
            -0.35547438,
            -0.58301455,
            0.017503755,
            0.48018637,
            0.2812776,
            0.17325264,
            0.25290966,
            -0.34936222,
            0.1046262,
            0.23808312,
            0.04101063,
            -0.018264974,
            -0.41155443,
            0.05262203,
            -0.0070503554,
            0.040844005,
            0.021735633,
            0.39907792,
            -0.14221595,
            0.47706604,
            -0.051643267,
            -0.4160352,
            -0.030591374,
            -0.47033653,
            0.21948542,
            -0.1964188,
            -0.26143283,
            -0.24508657,
            -0.3017739,
            -0.41032374,
            0.92093635,
            0.24501008,
            0.3010645,
            0.29330993,
            0.12946537,
            0.019708492,
            0.04740788,
            -0.3660365,
            0.26281777,
            -0.38427636,
            -0.58510345,
            0.15892166,
            0.45905378,
            0.07957403,
            0.24658777,
            0.13459058,
            0.054463577,
            -0.019414136,
            0.21921758,
            -0.13905998,
            -0.035978105,
            -0.05190766,
            -0.16323759,
            0.446703,
            0.14784926,
            -0.0782991,
            0.09661436,
            -0.03541751,
            -2.5405898e-32,
            0.47014305,
            0.15904875,
            0.50200224,
            0.2654104,
            0.33997092,
            0.12910001,
            -0.5123285,
            -0.103606224,
            -0.05372004,
            -0.23638491,
            -0.50632966,
            -0.039449897,
            -0.030364433,
            -0.23906739,
            -0.18272881,
            -0.057134774,
            -0.19576363,
            0.44058284,
            0.37159362,
            -0.011719841,
            -0.0060884506,
            0.43028045,
            0.07784771,
            0.43517303,
            0.10791526,
            0.12998937,
            0.06180446,
            -0.3601092,
            0.5932494,
            0.15001611,
            -0.48684338,
            -0.15010531,
            0.27467212,
            0.28448686,
            -0.087603174,
            -0.66473657,
            -0.46131966,
            0.24596177,
            -0.15279156,
            0.12963088,
            0.59755456,
            0.1384039,
            0.05168697,
            -0.19570132,
            -0.15571915,
            -0.3702244,
            0.29922798,
            0.38483346,
            -0.035737332,
            -0.0112374155,
            -0.08088391,
            -0.15966672,
            0.16647367,
            -0.044745505,
            -0.27767184,
            -0.2744285,
            -0.23501082,
            -0.012161096,
            0.36595383,
            0.4515806,
            0.45031846,
            0.22538064,
            -0.6230848,
            -0.71849126,
            -0.5590863,
            0.0066277967,
            0.33780327,
            -0.015154123,
            -0.21851556,
            0.18082319,
            0.26284298,
            -0.13966434,
            0.39235282,
            -0.42319036,
            -0.05361036,
            0.047728196,
            0.7335229,
            -0.19867128,
            -0.03703798,
            0.069436826,
            0.2200169,
            0.14777786,
            0.321534,
            0.08367285,
            0.26668835,
            -0.23587625,
            0.06728106,
            -0.02338601,
            -0.08651182,
            0.018419398,
            -0.42180195,
            0.0018919756,
            0.3552283,
            -0.5347795,
            0.29034477,
            2.776599e-32,
            0.5752948,
            0.17907967,
            -0.07624558,
            0.8576538,
            -0.1823696,
            -0.11323098,
            -0.14679062,
            0.5773324,
            -0.18004651,
            0.2901449,
            -0.06906002,
            -0.08325657,
            0.52954584,
            -0.0080984095,
            0.050075233,
            0.030888483,
            0.54806143,
            0.004300095,
            -0.400905,
            -0.17824252,
            -0.4217523,
            0.3763392,
            0.2345022,
            -0.3367548,
            0.22039871,
            0.18056786,
            0.23834513,
            0.16711481,
            -0.26729947,
            -0.36650363,
            -0.22120965,
            -0.1298461,
            -0.32790497,
            -0.115279,
            0.053716335,
            0.3437524,
            0.08407261,
            -0.041746933,
            -0.39772236,
            0.0042545325,
            -0.064964235,
            -0.027934246,
            -0.20807804,
            0.5595079,
            -0.095338225,
            0.23559533,
            0.061502274,
            0.44133386,
            0.26653978,
            0.039179925,
            -0.23953135,
            -0.03994049,
            -0.12543418,
            0.002512331,
            -0.31797346,
            -0.10949094,
            0.015860835,
            -0.37284505,
            -0.26002386,
            -0.056559857,
            -0.24573009,
            -0.44363555,
            -0.44228315,
            0.40245554,
            0.17621775,
            -0.19949736,
            -0.094719686,
            0.3066672,
            -0.16247351,
            0.13748024,
            0.12033085,
            0.0692291,
            -0.49683723,
            -0.3526328,
            -0.284268,
            0.0017397503,
            0.34732404,
            -0.16513069,
            -0.050970476,
            0.16590232,
            -0.054834757,
            0.1760121,
            -0.08390907,
            -0.014992371,
            -0.09376677,
            -0.23852046,
            0.08522638,
            -0.041989576,
            -0.09593543,
            0.2982405,
            0.017753601,
            0.17772289,
            -0.058010545,
            0.15282509,
            0.09495619,
            -8.90667e-8,
            0.09343702,
            0.093091846,
            -0.35675547,
            -0.101335086,
            -0.28447294,
            0.037397265,
            0.12092298,
            -0.10358507,
            -0.18559027,
            0.19251801,
            -0.010169278,
            -0.077650666,
            -0.04040389,
            0.038562614,
            0.398581,
            0.15078975,
            -0.010722627,
            -0.1369494,
            -0.23022191,
            -0.07499645,
            0.0015943846,
            -0.26432306,
            0.36633483,
            0.6141245,
            -0.01593933,
            -0.02889085,
            -0.0455574,
            0.19643791,
            -0.24832849,
            0.018728893,
            -0.19981956,
            0.29144582,
            -0.2931111,
            -0.044640984,
            -0.09011545,
            -0.19167201,
            0.27613762,
            0.061406102,
            0.18194114,
            -0.7694629,
            0.11186733,
            0.25077644,
            0.003472959,
            0.0015984625,
            -0.12014166,
            -0.37700155,
            0.19320922,
            0.13076204,
            -0.26576373,
            -0.1264124,
            0.10198319,
            0.29302177,
            0.037032265,
            0.5087324,
            0.18053712,
            -0.08459523,
            -0.079470575,
            0.19753295,
            -0.24774618,
            -0.13636033,
            0.21189696,
            0.16998023,
            -0.35827938,
            0.13053459
        )
            .map { it.toFloat() }

        embedding2.size shouldBe EMBEDDING_DIMENSION
        assertEmbeddingCloseTo(embedding2, rustOutput2)
    }

    /**
     * Asserts that [actual] equals [expected] element-wise, within an absolute [tolerance].
     *
     * Exact `shouldBe` on float lists is brittle when the expected values come from another
     * implementation: last-bit rounding differences would fail the test. On failure, this
     * helper reports every diverging index instead of a diff of the whole list.
     *
     * @param actual the computed embedding.
     * @param expected the reference embedding. Must have the same size as [actual].
     * @param tolerance the maximum allowed absolute difference per component.
     */
    private fun assertEmbeddingCloseTo(
        actual: List<Float>,
        expected: List<Float>,
        tolerance: Float = EMBEDDING_TOLERANCE,
    ) {
        actual.size shouldBe expected.size
        val mismatches = actual.indices
            // Written as !(diff <= tol) so that NaN components count as mismatches.
            .filter { i -> !(abs(actual[i] - expected[i]) <= tolerance) }
            .map { i -> "index $i: expected ${expected[i]} but was ${actual[i]}" }
        mismatches shouldBe emptyList<String>()
    }

    private companion object {
        /** Expected number of components in each embedding vector. */
        const val EMBEDDING_DIMENSION = 384

        /** Maximum absolute per-component difference accepted against a reference vector. */
        const val EMBEDDING_TOLERANCE = 1e-4f
    }
}